Working with Samples of Distributions over Convolutional Kernels

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt

from tensorflow.examples.tutorials.mnist import input_data as mnist_data
In [3]:
tf.__version__
Out[3]:
'1.2.1'
In [4]:
sess = tf.InteractiveSession()
In [5]:
mnist = mnist_data.read_data_sets("/home/tiao/Desktop/MNIST")
Extracting /home/tiao/Desktop/MNIST/train-images-idx3-ubyte.gz
Extracting /home/tiao/Desktop/MNIST/train-labels-idx1-ubyte.gz
Extracting /home/tiao/Desktop/MNIST/t10k-images-idx3-ubyte.gz
Extracting /home/tiao/Desktop/MNIST/t10k-labels-idx1-ubyte.gz
In [6]:
# 50 single-channel (grayscale) 28x28 images
x = mnist.train.images[:50].reshape(-1, 28, 28, 1)
x.shape
Out[6]:
(50, 28, 28, 1)
In [7]:
fig, ax = plt.subplots(figsize=(5, 5))

# showing an arbitrarily chosen image
ax.imshow(np.squeeze(x[5], axis=-1), cmap='gray')

plt.show()

Standard 2D Convolution with conv2d

In [8]:
# 32 kernels of size 5x5x1
kernel = tf.truncated_normal([5, 5, 1, 32], stddev=0.1)
kernel.get_shape().as_list()
Out[8]:
[5, 5, 1, 32]
In [9]:
x_conved = tf.nn.conv2d(x, kernel, 
                        strides=[1, 1, 1, 1], 
                        padding='SAME')
x_conved.get_shape().as_list()
Out[9]:
[50, 28, 28, 32]
In [10]:
x_conved[5, ..., 0].eval().shape
Out[10]:
(28, 28)
In [11]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

# showing what the 0th filter looks like
ax1.imshow(kernel[..., 0, 0].eval(), cmap='gray')

# show the previous arbitrarily chosen image
# convolved with the 0th filter
ax2.imshow(x_conved[5, ..., 0].eval(), cmap='gray')

plt.show()

Sample from a Distribution over Kernels

In [12]:
# 8x32 kernels of size 5x5x1
kernels = tf.truncated_normal([8, 5, 5, 1, 32], stddev=0.1)
kernels.get_shape().as_list()
Out[12]:
[8, 5, 5, 1, 32]

Approach 1: Map over samples with conv2d

In [13]:
x_tiled = tf.tile(tf.expand_dims(x, 0), [8, 1, 1, 1, 1])
x_tiled.get_shape().as_list()
Out[13]:
[8, 50, 28, 28, 1]
In [19]:
tf.nn.conv2d(x_tiled[0], kernels[0], 
             strides=[1, 1, 1, 1], 
             padding='SAME').get_shape().as_list()
Out[19]:
[50, 28, 28, 32]
In [15]:
x_conved1 = tf.map_fn(lambda args: tf.nn.conv2d(*args, strides=[1, 1, 1, 1], padding='SAME'),
                      elems=(x_tiled, kernels), dtype=tf.float32)
x_conved1.get_shape().as_list()
Out[15]:
[8, 50, 28, 28, 32]

Approach 2: Flattening

In [16]:
kernels_flat = tf.reshape(tf.transpose(kernels, 
                                       perm=(1, 2, 3, 4, 0)), 
                          shape=(5, 5, 1, 32*8))
kernels_flat.get_shape().as_list()
Out[16]:
[5, 5, 1, 256]
In [17]:
x_conved2 = tf.transpose(tf.reshape(tf.nn.conv2d(x, kernels_flat, 
                                                 strides=[1, 1, 1, 1], 
                                                 padding='SAME'), 
                                    shape=(50, 28, 28, 32, 8)), 
                         perm=(4, 0, 1, 2, 3))
x_conved2.get_shape().as_list()
Out[17]:
[8, 50, 28, 28, 32]
In [18]:
tf.reduce_all(tf.equal(x_conved1, x_conved2)).eval()
Out[18]:
True

Variational Inference with Implicit Approximate Inference Models - @fhuszar's Explaining Away Example Pt. 1 (WIP)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [70]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Add, Dense, Dot, Input
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import HTML, SVG, display_html
from tqdm import tnrange, tqdm_notebook
In [3]:
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3
PRETRAIN_EPOCHS = 60

Bayesian Logistic Regression (Synthetic Data)

In [7]:
z_min, z_max = -5, 5
In [8]:
z1, z2 = np.mgrid[z_min:z_max:300j, z_min:z_max:300j]
In [9]:
z_grid = np.dstack((z1, z2))
z_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(z_grid)
log_prior.shape
Out[11]:
(300, 300)
In [13]:
np.allclose(log_prior, 
            -.5*np.sum(z_grid**2, axis=2)/PRIOR_VARIANCE \
            -np.log(2*np.pi*PRIOR_VARIANCE))
Out[13]:
True
In [15]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(z1, z2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(z_min, z_max)
ax.set_ylim(z_min, z_max)

plt.show()
In [16]:
x = np.array([0, 5, 8, 12, 50])
In [37]:
def log_likelihood(z, x, beta_0=3., beta_1=1.):
    beta = beta_0 + np.sum(beta_1*np.maximum(0, z**3), axis=-1)
    return -np.log(beta) - x/beta
In [44]:
llhs = log_likelihood(z_grid, x.reshape(-1, 1, 1))
llhs.shape
Out[44]:
(5, 300, 300)
In [59]:
fig, axes = plt.subplots(ncols=len(x), nrows=1, figsize=(20, 4))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(z1, z2, llhs[i,::,::], cmap=plt.cm.magma)

    ax.set_xlim(z_min, z_max)
    ax.set_ylim(z_min, z_max)
    
    ax.set_title('$p(x = {{{0}}} \mid z)$'.format(x[i]))
    ax.set_xlabel('$z_1$')    
    
    if not i:
        ax.set_ylabel('$z_2$')

plt.show()
In [60]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(z1, z2, np.sum(llhs, axis=0), 
                cmap=plt.cm.magma)

ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')

ax.set_xlim(z_min, z_max)
ax.set_ylim(z_min, z_max)

plt.show()
In [61]:
fig, axes = plt.subplots(ncols=len(x), nrows=1, figsize=(20, 4))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(z1, z2,  np.exp(log_prior+llhs[i,::,::]), 
                cmap='magma')

    ax.set_xlim(z_min, z_max)
    ax.set_ylim(z_min, z_max)
    
    ax.set_title('$Zp(z \mid x = {{{0}}})$'.format(x[i]))
    ax.set_xlabel('$z_1$')    
    
    if not i:
        ax.set_ylabel('$z_2$')

plt.show()
In [63]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(z1, z2, 
            np.exp(log_prior+np.sum(llhs, axis=0)), 
            cmap='magma')

ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')

ax.set_xlim(z_min, z_max)
ax.set_ylim(z_min, z_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

In [84]:
x_input = Input(shape=(1,), name='x')
x_hidden = Dense(10, activation='relu')(x_input)
x_hidden = Dense(20, activation='relu')(x_hidden)
In [85]:
z_input = Input(shape=(LATENT_DIM,), name='z')
z_hidden = Dense(10, activation='relu')(z_input)
z_hidden = Dense(20, activation='relu')(z_hidden)
In [86]:
discrim_hidden = Add()([x_hidden, z_hidden])
discrim_hidden = Dense(10, activation='relu')(discrim_hidden)
discrim_hidden = Dense(20, activation='relu')(discrim_hidden)
discrim_logit = Dense(1, activation=None, 
                       name='logit')(discrim_hidden)
discrim_out = Activation('sigmoid')(discrim_logit)
In [87]:
discriminator = Model(inputs=[x_input, z_input], outputs=discrim_out)
discriminator.compile(optimizer=Adam(lr=LEARNING_RATE),
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [88]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discrim_logit)
In [89]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[89]:
G 140115456278480 x: InputLayerinput:output:(None, 1)(None, 1)140115456277696 dense_17: Denseinput:output:(None, 1)(None, 10)140115456278480->140115456277696 140115452033176 z: InputLayerinput:output:(None, 2)(None, 2)140115453473064 dense_19: Denseinput:output:(None, 2)(None, 10)140115452033176->140115453473064 140115453258048 dense_18: Denseinput:output:(None, 10)(None, 20)140115456277696->140115453258048 140115452043616 dense_20: Denseinput:output:(None, 10)(None, 20)140115453473064->140115452043616 140115531775392 add_8: Addinput:output:[(None, 20), (None, 20)](None, 20)140115453258048->140115531775392 140115452043616->140115531775392 140115531775168 dense_21: Denseinput:output:(None, 20)(None, 10)140115531775392->140115531775168 140115531775504 dense_22: Denseinput:output:(None, 10)(None, 20)140115531775168->140115531775504 140115543993984 logit: Denseinput:output:(None, 20)(None, 1)140115531775504->140115543993984 140115453098640 activation_2: Activationinput:output:(None, 1)(None, 1)140115543993984->140115453098640
In [105]:
np.ones((32, 5)) + np.ones((16, 5))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-105-a367db8def8f> in <module>()
----> 1 np.ones((32, 5)) + np.ones((16, 5))

ValueError: operands could not be broadcast together with shapes (32,5) (16,5) 
In [112]:
z_grid_ratio = ratio_estimator.predict([np.ones((16, 1)), np.ones((32, 2))])
z_grid_ratio.shape
Out[112]:
(16, 1)

Initial density ratio, prior to any training

In [ ]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [ ]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [ ]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [ ]:
phi = inference.trainable_weights
phi
In [ ]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
In [ ]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
In [ ]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
In [ ]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [ ]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [ ]:
metrics = discriminator.evaluate(inputs, targets)
In [ ]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [ ]:
metrics
In [ ]:
metrics_dict = dict(zip(discriminator.metrics_names, metrics))
In [ ]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))

metrics_plots = {k:ax1.plot([], label=k)[0] 
                 for k in ['loss']} # discriminator.metrics_names}

ax1.set_xlabel('epoch')
ax1.legend(loc='upper left')

ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
Discriminator pre-training
In [ ]:
def train_animate(epoch_num, prog_bar, batch_size=200, steps_per_epoch=15):

    # Single training epoch
    
    for step in tnrange(steps_per_epoch, unit='step', leave=False):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    # Plot Metrics
        
    metrics_dict = dict(zip(discriminator.metrics_names, metrics))

    for metric in metrics_plots:
        metrics_plots[metric].set_xdata(np.append(metrics_plots[metric].get_xdata(), 
                                                  epoch_num))    
        metrics_plots[metric].set_ydata(np.append(metrics_plots[metric].get_ydata(), 
                                                  metrics_dict[metric]))
        metrics_plots[metric].set_label('{} ({:.2f})' \
                                        .format(metric, 
                                                metrics_dict[metric]))
    
    ax1.set_xlabel('epoch {:2d}'.format(epoch_num))
    ax1.legend(loc='upper left')

    ax1.relim()
    ax1.autoscale_view()
    
    # Contour Plot
    
    ax2.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')

    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    
    # Progress Bar Updates
    
    prog_bar.update()
    prog_bar.set_postfix(**metrics_dict)

    return list(metrics_plots.values())
In [ ]:
# main training loop is managed by higher-order
# FuncAnimation which makes calls to an `animate` 
# function that encapsulates the logic of single
# training epoch. Has benefit of producing 
# animation but can incur significant overhead
with tqdm_notebook(total=PRETRAIN_EPOCHS, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=PRETRAIN_EPOCHS,
                         interval=200, # 5 fps
                         blit=True)

    anim_html5_video = anim.to_html5_video()
In [ ]:
HTML(anim_html5_video)
In [ ]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [ ]:
metrics = discriminator.evaluate(inputs, targets)
In [ ]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [ ]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

metrics_dict = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**metrics_dict), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Evidence lower bound

In [ ]:
def set_trainable(model, trainable):
    """inorder traversal"""
    model.trainable = trainable

    if isinstance(model, Model): # i.e. has layers
        for layer in model.layers:
            set_trainable(layer, trainable)
In [ ]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
In [ ]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
In [ ]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [ ]:
sess = K.get_session()
In [ ]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
In [ ]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Reweight likelihood term!

In [ ]:
def make_elbo(ratio_estimator):
    
    set_trainable(ratio_estimator, False)
    
    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(2.*log_likelihood-kl_estimate, axis=-1)

    return elbo
In [ ]:
elbo = make_elbo(ratio_estimator)
In [ ]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [ ]:
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
In [ ]:
inference.compile(loss=inference_loss, 
                  optimizer=Adam(lr=LEARNING_RATE))
In [ ]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [ ]:
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
In [ ]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
In [ ]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))

Adversarial Training

In [ ]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))

global_epoch = 0

loss_plot_inference, = ax1.plot([], label='inference')
loss_plot_discrim, = ax1.plot([], label='discriminator')

ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.legend(loc='upper left')

ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [ ]:
def train_animate(epoch_num, prog_bar, batch_size=200, 
                  steps_per_epoch=15):

    global global_epoch, loss_plot_inference, loss_plot_discrim
    
    # Single training epoch

    ## Ratio estimator training
        
    set_trainable(discriminator, True)

    for _ in tnrange(3*50, unit='step', desc='discriminator', 
                     leave=False):

        w_sample_prior = prior.rvs(size=BATCH_SIZE)

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))

        metrics_discrim = discriminator.train_on_batch(inputs, targets)

    metrics_dict_discrim = dict(zip(discriminator.metrics_names, 
                                    np.atleast_1d(metrics_discrim)))
    
    ## Inference model training
    
    set_trainable(ratio_estimator, False)

    y_tiled = np.tile(y, reps=(BATCH_SIZE, 1))

    for _ in tnrange(1, unit='step', desc='inference', leave=False):

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, y_tiled)
        
    metrics_dict_inference = dict(zip(inference.metrics_names, 
                                      np.atleast_1d(metrics_inference)))

    global_epoch += 1
    
    # Plot Loss
 
    loss_plot_inference.set_xdata(np.append(loss_plot_inference.get_xdata(),
                                            global_epoch))
    loss_plot_inference.set_ydata(np.append(loss_plot_inference.get_ydata(), 
                                            metrics_dict_inference['loss']))

    loss_plot_inference.set_label('inference ({:.2f})' \
                                  .format(metrics_dict_inference['loss']))

    loss_plot_discrim.set_xdata(np.append(loss_plot_discrim.get_xdata(),
                                          global_epoch))
    loss_plot_discrim.set_ydata(np.append(loss_plot_discrim.get_ydata(),
                                          metrics_dict_discrim['loss']))

    loss_plot_discrim.set_label('discriminator ({:.2f})' \
                                  .format(metrics_dict_discrim['loss']))
    
    ax1.set_xlabel('epoch {:2d}'.format(global_epoch))
    ax1.legend(loc='upper left')
    
    ax1.relim()
    ax1.autoscale_view()
    
    # Contour Plot
    
    ax2.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')

    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    
    # Progress Bar Updates
    
    prog_bar.update()
    prog_bar.set_postfix(loss_inference=metrics_dict_inference['loss'],
                         loss_discriminator=metrics_dict_discrim['loss'])

    return loss_plot_inference, loss_plot_discrim
In [ ]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
In [ ]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
In [ ]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
In [ ]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)

Evaluating the model

In [ ]:
w_sample_prior = prior.rvs(size=128)
eps = np.random.randn(256, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(128), np.ones(256)))
In [ ]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [ ]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

ax1.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax1.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax1.set_xlabel('$w_1$')

ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)

ax2.contourf(w1, w2, np.sum(llhs, axis=2), 
             cmap=plt.cm.magma)
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [ ]:
eps = np.random.randn(5000, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
In [ ]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

ax1.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax1.scatter(*inference.predict(eps[::10]).T, 
            s=4.**2, alpha=.6, cmap='coolwarm_r')

ax1.set_xlabel('$w_1$')
ax1.set_ylabel('$w_2$')

ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)

sns.kdeplot(*inference.predict(eps).T,
            cmap='magma', ax=ax2)

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [ ]:
 

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 9)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import HTML, SVG, display_html
from tqdm import tnrange, tqdm_notebook
Using TensorFlow backend.
In [3]:
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3
PRETRAIN_EPOCHS = 60

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
log_prior = -np.sum(w_grid**2, axis=2)/2/PRIOR_VARIANCE
log_prior.shape
Out[12]:
(300, 300)
In [13]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [14]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([  .5, -1.])
In [15]:
X = np.vstack((x1, x2, x3))
X.shape
Out[15]:
(3, 2)
In [16]:
y1 = 1
y2 = 1
y3 = 0
In [17]:
y = np.stack((y1, y2, y3))
y.shape
Out[17]:
(3,)
In [18]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [19]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[19]:
(300, 300, 3)
In [20]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [22]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [23]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer=Adam(lr=LEARNING_RATE),
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [24]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [25]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[25]:
G 139900251159464 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)139900250690168 dense_1: Denseinput:output:(None, 2)(None, 10)139900251159464->139900250690168 139900251158064 dense_2: Denseinput:output:(None, 10)(None, 20)139900250690168->139900251158064 139900251160192 logit: Denseinput:output:(None, 20)(None, 1)139900251158064->139900251160192 139900251182920 activation_1: Activationinput:output:(None, 1)(None, 1)139900251160192->139900251182920
In [26]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [27]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [28]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[28]:
[0.53340649604797363, 1.0]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [29]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [30]:
phi = inference.trainable_weights
phi
Out[30]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [31]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[31]:
G 139900250463256 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)139900250462024 dense_3: Denseinput:output:(None, 3)(None, 10)139900250463256->139900250462024 139900250461520 dense_4: Denseinput:output:(None, 10)(None, 20)139900250462024->139900250461520 139900251331832 dense_5: Denseinput:output:(None, 20)(None, 2)139900250461520->139900251331832
In [32]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[32]:
(200, 2)
In [33]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[33]:
(200, 2)
In [34]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [35]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [36]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [37]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [38]:
metrics
Out[38]:
[0.65425197124481205, 0.58999999999999997]
In [39]:
metrics_dict = dict(zip(discriminator.metrics_names, metrics))
In [40]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))

metrics_plots = {k:ax1.plot([], label=k)[0] 
                 for k in ['loss']} # discriminator.metrics_names}

ax1.set_xlabel('epoch')
ax1.legend(loc='upper left')

ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
Discriminator pre-training
In [41]:
def train_animate(epoch_num, prog_bar, batch_size=200, steps_per_epoch=15):

    # Single training epoch
    
    for step in tnrange(steps_per_epoch, unit='step', leave=False):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    # Plot Metrics
        
    metrics_dict = dict(zip(discriminator.metrics_names, metrics))

    for metric in metrics_plots:
        metrics_plots[metric].set_xdata(np.append(metrics_plots[metric].get_xdata(), 
                                                  epoch_num))    
        metrics_plots[metric].set_ydata(np.append(metrics_plots[metric].get_ydata(), 
                                                  metrics_dict[metric]))
        metrics_plots[metric].set_label('{} ({:.2f})' \
                                        .format(metric, 
                                                metrics_dict[metric]))
    
    ax1.set_xlabel('epoch {:2d}'.format(epoch_num))
    ax1.legend(loc='upper left')

    ax1.relim()
    ax1.autoscale_view()
    
    # Contour Plot
    
    ax2.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')

    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    
    # Progress Bar Updates
    
    prog_bar.update()
    prog_bar.set_postfix(**metrics_dict)

    return list(metrics_plots.values())
In [42]:
# main training loop is managed by higher-order
# FuncAnimation which makes calls to an `animate` 
# function that encapsulates the logic of single
# training epoch. Has benefit of producing 
# animation but can incur significant overhead
with tqdm_notebook(total=PRETRAIN_EPOCHS, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=PRETRAIN_EPOCHS,
                         interval=200, # 5 fps
                         blit=True)

    anim_html5_video = anim.to_html5_video()

In [43]:
HTML(anim_html5_video)
Out[43]:
In [44]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [45]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [46]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [47]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

metrics_dict = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**metrics_dict), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Evidence lower bound

In [48]:
def set_trainable(model, trainable):
    """inorder traversal"""
    model.trainable = trainable

    if isinstance(model, Model): # i.e. has layers
        for layer in model.layers:
            set_trainable(layer, trainable)
In [49]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[49]:
<tf.Tensor 'Sigmoid:0' shape=(300, 300, 3) dtype=float32>
In [50]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[50]:
<tf.Tensor 'mul_33:0' shape=(300, 300, 3) dtype=float32>
In [51]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [52]:
sess = K.get_session()
In [53]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[53]:
True
In [54]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Reweight likelihood term!

In [55]:
def make_elbo(ratio_estimator):
    
    set_trainable(ratio_estimator, False)
    
    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(2.*log_likelihood-kl_estimate, axis=-1)

    return elbo
In [56]:
elbo = make_elbo(ratio_estimator)
In [57]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [58]:
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
In [59]:
inference.compile(loss=inference_loss, 
                  optimizer=Adam(lr=LEARNING_RATE))
In [60]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [61]:
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
Out[61]:
<tf.Tensor 'concat:0' shape=(200, 3) dtype=float32>
In [62]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
Out[62]:
-4.4352226
In [63]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
 32/200 [===>..........................] - ETA: 0s
Out[63]:
4.4352222633361817

Adversarial Training

In [64]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))

global_epoch = 0

loss_plot_inference, = ax1.plot([], label='inference')
loss_plot_discrim, = ax1.plot([], label='discriminator')

ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.legend(loc='upper left')

ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [121]:
def train_animate(epoch_num, prog_bar, batch_size=200, 
                  steps_per_epoch=15):

    global global_epoch, loss_plot_inference, loss_plot_discrim
    
    # Single training epoch

    ## Ratio estimator training
        
    set_trainable(discriminator, True)

    for _ in tnrange(3*50, unit='step', desc='discriminator', 
                     leave=False):

        w_sample_prior = prior.rvs(size=BATCH_SIZE)

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))

        metrics_discrim = discriminator.train_on_batch(inputs, targets)

    metrics_dict_discrim = dict(zip(discriminator.metrics_names, 
                                    np.atleast_1d(metrics_discrim)))
    
    ## Inference model training

    print('Before:')
    print(ratio_estimator.get_weights()[0])
    print(inference.get_weights()[0])
    
    set_trainable(ratio_estimator, False)

    y_tiled = np.tile(y, reps=(BATCH_SIZE, 1))

    for _ in tnrange(1, unit='step', desc='inference', leave=False):

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, y_tiled)

    print('After:')
    print(ratio_estimator.get_weights()[0])
    print(inference.get_weights()[0])
        
    print('###')
        
    metrics_dict_inference = dict(zip(inference.metrics_names, 
                                      np.atleast_1d(metrics_inference)))

    global_epoch += 1
    
    # Plot Loss
 
    loss_plot_inference.set_xdata(np.append(loss_plot_inference.get_xdata(),
                                            global_epoch))
    loss_plot_inference.set_ydata(np.append(loss_plot_inference.get_ydata(), 
                                            metrics_dict_inference['loss']))

    loss_plot_inference.set_label('inference ({:.2f})' \
                                  .format(metrics_dict_inference['loss']))

    loss_plot_discrim.set_xdata(np.append(loss_plot_discrim.get_xdata(),
                                          global_epoch))
    loss_plot_discrim.set_ydata(np.append(loss_plot_discrim.get_ydata(),
                                          metrics_dict_discrim['loss']))

    loss_plot_discrim.set_label('discriminator ({:.2f})' \
                                  .format(metrics_dict_discrim['loss']))
    
    ax1.set_xlabel('epoch {:2d}'.format(global_epoch))
    ax1.legend(loc='upper left')
    
    ax1.relim()
    ax1.autoscale_view()
    
    # Contour Plot
    
    ax2.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')

    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    
    # Progress Bar Updates
    
    prog_bar.update()
    prog_bar.set_postfix(loss_inference=metrics_dict_inference['loss'],
                         loss_discriminator=metrics_dict_discrim['loss'])

    return loss_plot_inference, loss_plot_discrim
In [122]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Before:
[[ 0.4   0.69 -0.56 -0.45  0.03  0.23 -0.67 -0.19  0.21  0.71]
 [-0.33 -0.29  0.22  0.54 -0.58 -1.03 -0.29 -0.75  0.8  -0.22]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.02  0.78  0.12  0.65 -0.3 ]
 [ 0.41 -0.5   0.5   0.17 -0.21 -0.41  0.61 -0.13 -0.12 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.21 -0.67 -0.61  0.53]]
After:
[[ 0.4   0.69 -0.56 -0.45  0.03  0.23 -0.67 -0.19  0.21  0.71]
 [-0.33 -0.29  0.22  0.54 -0.58 -1.03 -0.29 -0.75  0.8  -0.22]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.02  0.78  0.12  0.65 -0.3 ]
 [ 0.41 -0.5   0.5   0.17 -0.21 -0.41  0.61 -0.13 -0.12 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.21 -0.67 -0.61  0.53]]
###
Before:
[[ 0.43  0.7  -0.54 -0.43  0.04  0.23 -0.68 -0.18  0.25  0.71]
 [-0.32 -0.28  0.18  0.53 -0.56 -1.02 -0.29 -0.75  0.8  -0.22]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.02  0.78  0.12  0.65 -0.3 ]
 [ 0.41 -0.5   0.5   0.17 -0.21 -0.41  0.61 -0.13 -0.12 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.21 -0.67 -0.61  0.53]]
After:
[[ 0.43  0.7  -0.54 -0.43  0.04  0.23 -0.68 -0.18  0.25  0.71]
 [-0.32 -0.28  0.18  0.53 -0.56 -1.02 -0.29 -0.75  0.8  -0.22]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.01  0.78  0.12  0.65 -0.3 ]
 [ 0.41 -0.5   0.5   0.17 -0.22 -0.41  0.61 -0.13 -0.12 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.21 -0.67 -0.61  0.53]]
###
Before:
[[ 0.43  0.7  -0.52 -0.42  0.05  0.23 -0.68 -0.18  0.24  0.7 ]
 [-0.34 -0.28  0.2   0.55 -0.54 -1.03 -0.29 -0.74  0.79 -0.22]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.01  0.78  0.12  0.65 -0.3 ]
 [ 0.41 -0.5   0.5   0.17 -0.22 -0.41  0.61 -0.13 -0.12 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.21 -0.67 -0.61  0.53]]
After:
[[ 0.43  0.7  -0.52 -0.42  0.05  0.23 -0.68 -0.18  0.24  0.7 ]
 [-0.34 -0.28  0.2   0.55 -0.54 -1.03 -0.29 -0.74  0.79 -0.22]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.01  0.78  0.12  0.65 -0.3 ]
 [ 0.41 -0.5   0.5   0.17 -0.22 -0.41  0.61 -0.13 -0.12 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.21 -0.67 -0.61  0.53]]
###
Before:
[[ 0.42  0.71 -0.54 -0.41  0.07  0.23 -0.67 -0.13  0.25  0.71]
 [-0.32 -0.28  0.18  0.55 -0.58 -1.02 -0.3  -0.79  0.81 -0.23]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.01  0.78  0.12  0.65 -0.3 ]
 [ 0.41 -0.5   0.5   0.17 -0.22 -0.41  0.61 -0.13 -0.12 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.21 -0.67 -0.61  0.53]]
After:
[[ 0.42  0.71 -0.54 -0.41  0.07  0.23 -0.67 -0.13  0.25  0.71]
 [-0.32 -0.28  0.18  0.55 -0.58 -1.02 -0.3  -0.79  0.81 -0.23]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.01  0.78  0.12  0.65 -0.3 ]
 [ 0.41 -0.5   0.5   0.17 -0.22 -0.41  0.61 -0.13 -0.12 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.21 -0.67 -0.61  0.53]]
###
Before:
[[ 0.39  0.7  -0.52 -0.4   0.06  0.23 -0.67 -0.15  0.27  0.7 ]
 [-0.36 -0.28  0.2   0.52 -0.56 -1.03 -0.3  -0.74  0.8  -0.22]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.01  0.78  0.12  0.65 -0.3 ]
 [ 0.41 -0.5   0.5   0.17 -0.22 -0.41  0.61 -0.13 -0.12 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.21 -0.67 -0.61  0.53]]
After:
[[ 0.39  0.7  -0.52 -0.4   0.06  0.23 -0.67 -0.15  0.27  0.7 ]
 [-0.36 -0.28  0.2   0.52 -0.56 -1.03 -0.3  -0.74  0.8  -0.22]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.01  0.78  0.12  0.65 -0.3 ]
 [ 0.41 -0.5   0.5   0.17 -0.22 -0.41  0.61 -0.13 -0.12 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.22 -0.67 -0.61  0.53]]
###
Before:
[[ 0.37  0.67 -0.54 -0.44  0.05  0.24 -0.68 -0.15  0.27  0.7 ]
 [-0.39 -0.32  0.2   0.52 -0.56 -1.04 -0.3  -0.73  0.81 -0.21]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.01  0.78  0.12  0.65 -0.3 ]
 [ 0.41 -0.5   0.5   0.17 -0.22 -0.41  0.61 -0.13 -0.12 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.22 -0.67 -0.61  0.53]]
After:
[[ 0.37  0.67 -0.54 -0.44  0.05  0.24 -0.68 -0.15  0.27  0.7 ]
 [-0.39 -0.32  0.2   0.52 -0.56 -1.04 -0.3  -0.73  0.81 -0.21]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.01  0.77  0.12  0.65 -0.31]
 [ 0.41 -0.5   0.5   0.17 -0.22 -0.41  0.61 -0.13 -0.13 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.22 -0.67 -0.61  0.53]]
###
Before:
[[ 0.39  0.68 -0.51 -0.42  0.05  0.23 -0.68 -0.17  0.27  0.71]
 [-0.37 -0.28  0.19  0.5  -0.54 -1.03 -0.3  -0.74  0.83 -0.22]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.01  0.77  0.12  0.65 -0.31]
 [ 0.41 -0.5   0.5   0.17 -0.22 -0.41  0.61 -0.13 -0.13 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.22 -0.67 -0.61  0.53]]
After:
[[ 0.39  0.68 -0.51 -0.42  0.05  0.23 -0.68 -0.17  0.27  0.71]
 [-0.37 -0.28  0.19  0.5  -0.54 -1.03 -0.3  -0.74  0.83 -0.22]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.01  0.77  0.12  0.65 -0.31]
 [ 0.41 -0.5   0.5   0.17 -0.22 -0.41  0.61 -0.13 -0.13 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.22 -0.67 -0.61  0.53]]
###
Before:
[[ 0.38  0.7  -0.56 -0.44  0.03  0.21 -0.67 -0.15  0.23  0.71]
 [-0.36 -0.27  0.2   0.51 -0.55 -1.01 -0.29 -0.74  0.81 -0.2 ]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.01  0.77  0.12  0.65 -0.31]
 [ 0.41 -0.5   0.5   0.17 -0.22 -0.41  0.61 -0.13 -0.13 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.22 -0.67 -0.61  0.53]]
After:
[[ 0.38  0.7  -0.56 -0.44  0.03  0.21 -0.67 -0.15  0.23  0.71]
 [-0.36 -0.27  0.2   0.51 -0.55 -1.01 -0.29 -0.74  0.81 -0.2 ]]
[[-0.33 -0.64  0.37 -0.19 -0.02  0.01  0.77  0.12  0.66 -0.3 ]
 [ 0.41 -0.5   0.5   0.17 -0.22 -0.41  0.61 -0.13 -0.13 -0.68]
 [ 0.71 -0.47 -0.69 -0.42 -0.21  0.41  0.22 -0.67 -0.61  0.53]]
###

---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-122-ab938e4e4048> in <module>()
      9                          blit=True)
     10 
---> 11     anim_html5_video = anim.to_html5_video()
     12 
     13 HTML(anim_html5_video)

~/.virtualenvs/implicit/lib/python3.5/site-packages/matplotlib/animation.py in to_html5_video(self)
   1207                                 bitrate=rcParams['animation.bitrate'],
   1208                                 fps=1000. / self._interval)
-> 1209                 self.save(f.name, writer=writer)
   1210 
   1211             # Now open and base64 encode

~/.virtualenvs/implicit/lib/python3.5/site-packages/matplotlib/animation.py in save(self, filename, writer, fps, dpi, codec, bitrate, extra_args, metadata, extra_anim, savefig_kwargs)
   1060                     for anim, d in zip(all_anim, data):
   1061                         # TODO: See if turning off blit is really necessary
-> 1062                         anim._draw_next_frame(d, blit=False)
   1063                     writer.grab_frame(**savefig_kwargs)
   1064 

~/.virtualenvs/implicit/lib/python3.5/site-packages/matplotlib/animation.py in _draw_next_frame(self, framedata, blit)
   1097         # post- draw, as well as the drawing of the frame itself.
   1098         self._pre_draw(framedata, blit)
-> 1099         self._draw_frame(framedata)
   1100         self._post_draw(framedata, blit)
   1101 

~/.virtualenvs/implicit/lib/python3.5/site-packages/matplotlib/animation.py in _draw_frame(self, framedata)
   1560         # Call the func with framedata and args. If blitting is desired,
   1561         # func needs to return a sequence of any artists that were modified.
-> 1562         self._drawn_artists = self._func(framedata, *self._args)
   1563         if self._blit:
   1564             if self._drawn_artists is None:

<ipython-input-121-78d9a9e0159c> in train_animate(epoch_num, prog_bar, batch_size, steps_per_epoch)
     78     ax2.cla()
     79 
---> 80     w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
     81     w_grid_ratio = w_grid_ratio.reshape(300, 300)
     82 

~/.virtualenvs/implicit/lib/python3.5/site-packages/keras/engine/training.py in predict(self, x, batch_size, verbose)
   1515         f = self.predict_function
   1516         return self._predict_loop(f, ins,
-> 1517                                   batch_size=batch_size, verbose=verbose)
   1518 
   1519     def train_on_batch(self, x, y,

~/.virtualenvs/implicit/lib/python3.5/site-packages/keras/engine/training.py in _predict_loop(self, f, ins, batch_size, verbose)
   1139                 ins_batch = _slice_arrays(ins, batch_ids)
   1140 
-> 1141             batch_outs = f(ins_batch)
   1142             if not isinstance(batch_outs, list):
   1143                 batch_outs = [batch_outs]

~/.virtualenvs/implicit/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2263                 value = (indices, sparse_coo.data, sparse_coo.shape)
   2264             feed_dict[tensor] = value
-> 2265         session = get_session()
   2266         updated = session.run(self.outputs + [self.updates_op],
   2267                               feed_dict=feed_dict,

~/.virtualenvs/implicit/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in get_session()
    166     if not _MANUAL_VAR_INIT:
    167         with session.graph.as_default():
--> 168             _initialize_variables()
    169     return session
    170 

~/.virtualenvs/implicit/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in _initialize_variables()
    334     uninitialized_variables = []
    335     for v in variables:
--> 336         if not hasattr(v, '_keras_initialized') or not v._keras_initialized:
    337             uninitialized_variables.append(v)
    338             v._keras_initialized = True

KeyboardInterrupt: 
In [67]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)

Out[67]:
In [68]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)

---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-68-ab938e4e4048> in <module>()
      9                          blit=True)
     10 
---> 11     anim_html5_video = anim.to_html5_video()
     12 
     13 HTML(anim_html5_video)

~/.virtualenvs/implicit/lib/python3.5/site-packages/matplotlib/animation.py in to_html5_video(self)
   1207                                 bitrate=rcParams['animation.bitrate'],
   1208                                 fps=1000. / self._interval)
-> 1209                 self.save(f.name, writer=writer)
   1210 
   1211             # Now open and base64 encode

~/.virtualenvs/implicit/lib/python3.5/site-packages/matplotlib/animation.py in save(self, filename, writer, fps, dpi, codec, bitrate, extra_args, metadata, extra_anim, savefig_kwargs)
   1060                     for anim, d in zip(all_anim, data):
   1061                         # TODO: See if turning off blit is really necessary
-> 1062                         anim._draw_next_frame(d, blit=False)
   1063                     writer.grab_frame(**savefig_kwargs)
   1064 

~/.virtualenvs/implicit/lib/python3.5/site-packages/matplotlib/animation.py in _draw_next_frame(self, framedata, blit)
   1097         # post- draw, as well as the drawing of the frame itself.
   1098         self._pre_draw(framedata, blit)
-> 1099         self._draw_frame(framedata)
   1100         self._post_draw(framedata, blit)
   1101 

~/.virtualenvs/implicit/lib/python3.5/site-packages/matplotlib/animation.py in _draw_frame(self, framedata)
   1560         # Call the func with framedata and args. If blitting is desired,
   1561         # func needs to return a sequence of any artists that were modified.
-> 1562         self._drawn_artists = self._func(framedata, *self._args)
   1563         if self._blit:
   1564             if self._drawn_artists is None:

<ipython-input-65-5bea53f24968> in train_animate(epoch_num, prog_bar, batch_size, steps_per_epoch)
     16 
     17         eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
---> 18         w_sample_posterior = inference.predict(eps)
     19 
     20         inputs = np.vstack((w_sample_prior, w_sample_posterior))

~/.virtualenvs/implicit/lib/python3.5/site-packages/keras/models.py in predict(self, x, batch_size, verbose)
    907         if self.model is None:
    908             self.build()
--> 909         return self.model.predict(x, batch_size=batch_size, verbose=verbose)
    910 
    911     def predict_on_batch(self, x):

~/.virtualenvs/implicit/lib/python3.5/site-packages/keras/engine/training.py in predict(self, x, batch_size, verbose)
   1515         f = self.predict_function
   1516         return self._predict_loop(f, ins,
-> 1517                                   batch_size=batch_size, verbose=verbose)
   1518 
   1519     def train_on_batch(self, x, y,

~/.virtualenvs/implicit/lib/python3.5/site-packages/keras/engine/training.py in _predict_loop(self, f, ins, batch_size, verbose)
   1139                 ins_batch = _slice_arrays(ins, batch_ids)
   1140 
-> 1141             batch_outs = f(ins_batch)
   1142             if not isinstance(batch_outs, list):
   1143                 batch_outs = [batch_outs]

~/.virtualenvs/implicit/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2266         updated = session.run(self.outputs + [self.updates_op],
   2267                               feed_dict=feed_dict,
-> 2268                               **self.session_kwargs)
   2269         return updated[:len(self.outputs)]
   2270 

~/.virtualenvs/implicit/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    787     try:
    788       result = self._run(None, fetches, feed_dict, options_ptr,
--> 789                          run_metadata_ptr)
    790       if run_metadata:
    791         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~/.virtualenvs/implicit/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    995     if final_fetches or final_targets:
    996       results = self._do_run(handle, final_targets, final_fetches,
--> 997                              feed_dict_string, options, run_metadata)
    998     else:
    999       results = []

~/.virtualenvs/implicit/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1130     if handle is None:
   1131       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1132                            target_list, options, run_metadata)
   1133     else:
   1134       return self._do_call(_prun_fn, self._session, handle, feed_dict,

~/.virtualenvs/implicit/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1137   def _do_call(self, fn, *args):
   1138     try:
-> 1139       return fn(*args)
   1140     except errors.OpError as e:
   1141       message = compat.as_text(e.message)

~/.virtualenvs/implicit/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1119         return tf_session.TF_Run(session, options,
   1120                                  feed_dict, fetch_list, target_list,
-> 1121                                  status, run_metadata)
   1122 
   1123     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 
In [ ]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)

Evaluating the model

In [ ]:
w_sample_prior = prior.rvs(size=128)
eps = np.random.randn(256, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(128), np.ones(256)))
In [ ]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [ ]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

ax1.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax1.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax1.set_xlabel('$w_1$')

ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)

ax2.contourf(w1, w2, np.sum(llhs, axis=2), 
             cmap=plt.cm.magma)
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [ ]:
eps = np.random.randn(5000, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
In [ ]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

ax1.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax1.scatter(*inference.predict(eps[::10]).T, 
            s=4.**2, alpha=.6, cmap='coolwarm_r')

ax1.set_xlabel('$w_1$')
ax1.set_ylabel('$w_2$')

ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)

sns.kdeplot(*inference.predict(eps).T,
            cmap='magma', ax=ax2)

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [107]:
output = expit(np.random.randn(256))
target = np.hstack((np.zeros(128), np.ones(128)))
In [108]:
2*K.mean(K.binary_crossentropy(output=K.constant(output), 
                      target=K.constant(target))).eval(session=sess)
Out[108]:
1.7183432579040527
In [109]:
np.mean(-np.log(output[128:])-np.log(1-output[:128]))
Out[109]:
1.7183431791525026
In [94]:
(-np.log(output[:128])-np.log(1-output[128:])).shape
Out[94]:
(128,)
In [74]:
p1[:128][0]
Out[74]:
0.343090641899925
In [73]:
p1[128:][0]
Out[73]:
0.50225667608674573
In [113]:
ratio_estimator.get_weights()[0]
Out[113]:
array([[ 0.34,  0.75, -0.56, -0.46,  0.04,  0.25, -0.64, -0.32,  0.05,  0.67],
       [-0.38, -0.27,  0.19,  0.55, -0.63, -0.96, -0.27, -0.75,  0.82, -0.13]], dtype=float32)
In [ ]:
 

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 8)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import HTML, SVG, display_html
from tqdm import tnrange, tqdm_notebook
Using TensorFlow backend.
In [3]:
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3
PRETRAIN_EPOCHS = 60

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([  .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer=Adam(lr=LEARNING_RATE),
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 139996050319736 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)139996049881464 dense_1: Denseinput:output:(None, 2)(None, 10)139996050319736->139996049881464 139996049882416 dense_2: Denseinput:output:(None, 10)(None, 20)139996049881464->139996049882416 139996050321304 logit: Denseinput:output:(None, 20)(None, 1)139996049882416->139996050321304 139996050741624 activation_1: Activationinput:output:(None, 1)(None, 1)139996050321304->139996050741624
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.62235820293426514, 0.60000002384185791]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 139996049667352 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)139996049443528 dense_3: Denseinput:output:(None, 3)(None, 10)139996049667352->139996049443528 139996049443416 dense_4: Denseinput:output:(None, 10)(None, 20)139996049443528->139996049443416 139996049667520 dense_5: Denseinput:output:(None, 20)(None, 2)139996049443416->139996049667520
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(200, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(200, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
metrics
Out[37]:
[0.69308857917785649, 0.47749999999999998]
In [38]:
metrics_dict = dict(zip(discriminator.metrics_names, metrics))
In [39]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))

metrics_plots = {k:ax1.plot([], label=k)[0] 
                 for k in ['loss']} # discriminator.metrics_names}

ax1.set_xlabel('epoch')
ax1.legend(loc='upper left')

ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
Discriminator pre-training
In [40]:
def train_animate(epoch_num, prog_bar, batch_size=200, steps_per_epoch=15):

    # Single training epoch
    
    for step in tnrange(steps_per_epoch, unit='step', leave=False):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    # Plot Metrics
        
    metrics_dict = dict(zip(discriminator.metrics_names, metrics))

    for metric in metrics_plots:
        metrics_plots[metric].set_xdata(np.append(metrics_plots[metric].get_xdata(), 
                                                  epoch_num))    
        metrics_plots[metric].set_ydata(np.append(metrics_plots[metric].get_ydata(), 
                                                  metrics_dict[metric]))
        metrics_plots[metric].set_label('{} ({:.2f})' \
                                        .format(metric, 
                                                metrics_dict[metric]))
    
    ax1.set_xlabel('epoch {:2d}'.format(epoch_num))
    ax1.legend(loc='upper left')

    ax1.relim()
    ax1.autoscale_view()
    
    # Contour Plot
    
    ax2.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')

    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    
    # Progress Bar Updates
    
    prog_bar.update()
    prog_bar.set_postfix(**metrics_dict)

    return list(metrics_plots.values())
In [41]:
# main training loop is managed by higher-order
# FuncAnimation which makes calls to an `animate` 
# function that encapsulates the logic of single
# training epoch. Has benefit of producing 
# animation but can incur significant overhead
with tqdm_notebook(total=PRETRAIN_EPOCHS, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=PRETRAIN_EPOCHS,
                         interval=200, # 5 fps
                         blit=True)

    anim_html5_video = anim.to_html5_video()
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

In [42]:
HTML(anim_html5_video)
Out[42]:
In [43]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [44]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [45]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [46]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

metrics_dict = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**metrics_dict), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Evidence lower bound

In [47]:
def set_trainable(model, trainable):
    """inorder traversal"""
    model.trainable = trainable

    if isinstance(model, Model): # i.e. has layers
        for layer in model.layers:
            set_trainable(layer, trainable)
In [48]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[48]:
<tf.Tensor 'Sigmoid:0' shape=(300, 300, 3) dtype=float32>
In [49]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[49]:
<tf.Tensor 'mul_33:0' shape=(300, 300, 3) dtype=float32>
In [50]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [51]:
sess = K.get_session()
In [52]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[52]:
True
In [53]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [54]:
def make_elbo(ratio_estimator):
    
    set_trainable(ratio_estimator, False)
    
    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)

    return elbo
In [55]:
elbo = make_elbo(ratio_estimator)
In [56]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [57]:
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
In [58]:
inference.compile(loss=inference_loss, 
                  optimizer=Adam(lr=LEARNING_RATE))
In [59]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [60]:
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
Out[60]:
<tf.Tensor 'concat:0' shape=(200, 3) dtype=float32>
In [61]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
Out[61]:
-2.5127466
In [62]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
 32/200 [===>..........................] - ETA: 0s
Out[62]:
2.5127464294433595

Adversarial Training

In [63]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))

global_epoch = 0

loss_plot_inference, = ax1.plot([], label='inference')
loss_plot_discrim, = ax1.plot([], label='discriminator')

ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.legend(loc='upper left')

ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [64]:
def train_animate(epoch_num, prog_bar, batch_size=200, 
                  steps_per_epoch=15):

    global global_epoch, loss_plot_inference, loss_plot_discrim
    
    # Single training epoch

    ## Ratio estimator training
        
    set_trainable(discriminator, True)

    for _ in tnrange(3*50, unit='step', desc='discriminator', 
                     leave=False):

        w_sample_prior = prior.rvs(size=BATCH_SIZE)

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))

        metrics_discrim = discriminator.train_on_batch(inputs, targets)

    metrics_dict_discrim = dict(zip(discriminator.metrics_names, 
                                    np.atleast_1d(metrics_discrim)))
    
    ## Inference model training
    
    set_trainable(ratio_estimator, False)

    y_tiled = np.tile(y, reps=(BATCH_SIZE, 1))

    for _ in tnrange(1, unit='step', desc='inference', leave=False):

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, y_tiled)
        
    metrics_dict_inference = dict(zip(inference.metrics_names, 
                                      np.atleast_1d(metrics_inference)))

    global_epoch += 1
    
    # Plot Loss
 
    loss_plot_inference.set_xdata(np.append(loss_plot_inference.get_xdata(),
                                            global_epoch))
    loss_plot_inference.set_ydata(np.append(loss_plot_inference.get_ydata(), 
                                            metrics_dict_inference['loss']))

    loss_plot_inference.set_label('inference ({:.2f})' \
                                  .format(metrics_dict_inference['loss']))

    loss_plot_discrim.set_xdata(np.append(loss_plot_discrim.get_xdata(),
                                          global_epoch))
    loss_plot_discrim.set_ydata(np.append(loss_plot_discrim.get_ydata(),
                                          metrics_dict_discrim['loss']))

    loss_plot_discrim.set_label('discriminator ({:.2f})' \
                                  .format(metrics_dict_discrim['loss']))
    
    ax1.set_xlabel('epoch {:2d}'.format(global_epoch))
    ax1.legend(loc='upper left')
    
    ax1.relim()
    ax1.autoscale_view()
    
    # Contour Plot
    
    ax2.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')

    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    
    # Progress Bar Updates
    
    prog_bar.update()
    prog_bar.set_postfix(loss_inference=metrics_dict_inference['loss'],
                         loss_discriminator=metrics_dict_discrim['loss'])

    return loss_plot_inference, loss_plot_discrim
In [65]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[65]:
In [66]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[66]:
In [67]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[67]:
In [68]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[68]:

Evaluating the model

In [81]:
w_sample_prior = prior.rvs(size=128)
eps = np.random.randn(256, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(128), np.ones(256)))
In [82]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [85]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

ax1.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax1.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax1.set_xlabel('$w_1$')

ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)

ax2.contourf(w1, w2, np.sum(llhs, axis=2), 
             cmap=plt.cm.magma)
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [101]:
eps = np.random.randn(5000, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
In [105]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

ax1.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax1.scatter(*inference.predict(eps[::10]).T, 
            s=4.**2, alpha=.6, cmap='coolwarm_r')

ax1.set_xlabel('$w_1$')
ax1.set_ylabel('$w_2$')

ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)

sns.kdeplot(*inference.predict(eps).T,
            cmap='magma', ax=ax2)

plt.show()

Variational Inference with Implicit Models (forked from @fhuszar)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import theano

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.special import expit
from scipy.stats import logistic

from theano import tensor as T
from theano.tensor.shared_randomstreams import RandomStreams
from theano.printing import debugprint

from lasagne.updates import adam
from lasagne.utils import floatX
from lasagne.nonlinearities import sigmoid
from lasagne.layers import get_output, get_all_params
from lasagne.layers import (InputLayer,
                            DenseLayer,
                            NonlinearityLayer)


from matplotlib.animation import FuncAnimation
from IPython.display import HTML, SVG, display_html
from tqdm import tnrange, tqdm_notebook
/home/tiao/.virtualenvs/implicit/lib/python3.5/site-packages/theano/tensor/signal/downsample.py:6: UserWarning: downsample module has been moved to the theano.tensor.signal.pool module.
  "downsample module has been moved to the theano.tensor.signal.pool module.")
In [3]:
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3
PRETRAIN_EPOCHS = 60
In [6]:
PRETRAIN_EPOCHS*15
Out[6]:
900
In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
log_prior = -.5*np.sum(w_grid**2, axis=2)/PRIOR_VARIANCE
log_prior.shape
Out[10]:
(300, 300)
In [11]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [12]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([  .5, -1.])
In [13]:
X = np.vstack((x1, x2, x3))
X.shape
Out[13]:
(3, 2)
In [14]:
y1 = 1
y2 = 1
y3 = -1
In [15]:
y = np.stack((y1, y2, y3))
y.shape
Out[15]:
(3,)
In [16]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return logistic.logcdf(y*(np.dot(w.T,x)))
In [17]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[17]:
(300, 300, 3)
In [18]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [19]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [20]:
# unnormalised log posterior
# only for illustration purposes
log_post = log_prior + np.sum(llhs, axis=2)
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_post), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Fitting an approximate posterior

This part is for the actual GAN stuff. Here we define the generator and the discriminator networks in Lasagne, and code up the two loss functions in theano.

In [22]:
#defines a 'generator' network
def build_G(input_var=None, num_z = 3):
    
    network = InputLayer(input_var=input_var, shape=(None, num_z))
    
    network = DenseLayer(incoming = network, num_units=10)
    
    network = DenseLayer(incoming = network, num_units=20)
    
    network = DenseLayer(incoming = network, num_units=2, nonlinearity=None)
    
    return network
In [23]:
#defines the 'discriminator network'
def build_D(input_var=None):

    network = InputLayer(input_var=input_var, shape = (None, 2))
    
    network = DenseLayer(incoming = network, num_units=10)
    
    network = DenseLayer(incoming = network, num_units=20)
    
    network = DenseLayer(incoming = network, num_units=1, nonlinearity=None)
    
    normalised = NonlinearityLayer(incoming = network, nonlinearity = sigmoid)
    
    return { 'unnorm':network, 'norm':normalised }
In [71]:
#variables for input (design matrix), output labels, GAN noise variable, weights
x_var = T.matrix('design matrix')
y_var = T.vector('labels')
z_var = T.matrix('GAN noise')
w_var = T.matrix('weights')

#theano variables for things like batchsize, learning rate, etc.
batchsize_var = T.scalar('batchsize', dtype='int32')
prior_variance_var = T.scalar('prior variance')
learningrate_var = T.scalar('learning rate')

#random numbers for sampling from the prior or from the GAN
srng = RandomStreams(seed=13574437)
z_rnd = srng.normal((batchsize_var,3))
prior_rnd = srng.normal((batchsize_var,2))

#instantiating the G and D networks
generator = build_G(z_var)
discriminator = build_D()

#these expressions are random samples from the generator and the prior, respectively
samples_from_grenerator = get_output(generator, z_rnd)
samples_from_prior = prior_rnd*T.sqrt(prior_variance_var)

#discriminator output for synthetic samples, both normalised and unnormalised (after/before sigmoid)
D_of_G = get_output(discriminator['norm'], inputs=samples_from_grenerator)
s_of_G = get_output(discriminator['unnorm'], inputs=samples_from_grenerator)

#discriminator output for real samples from the prior
D_of_prior = get_output(discriminator['norm'], inputs=samples_from_prior)

#loss of discriminator - simple binary cross-entropy loss
loss_D = -T.log(D_of_G).mean() - T.log(1-D_of_prior).mean()

#log likelihood for each synthetic w sampled from the generator
log_likelihood = T.log(
    T.nnet.sigmoid(
        (y_var.dimshuffle(0,'x','x')*(x_var.dimshuffle(0,1,'x') * samples_from_grenerator.dimshuffle('x', 1, 0))).sum(1)
    )
).sum(0).mean()

#loss for G is the sum of unnormalised discriminator output and the negative log likelihood
loss_G = s_of_G.mean() - log_likelihood

#compiling theano functions:
evaluate_generator = theano.function(
    [z_var],
    get_output(generator),
    allow_input_downcast=True
)

sample_generator = theano.function(
    [batchsize_var],
    samples_from_grenerator,
    allow_input_downcast=True,
)

sample_prior = theano.function(
    [prior_variance_var, batchsize_var],
    samples_from_prior,
    allow_input_downcast=True
)

params_D = get_all_params(discriminator['norm'], trainable=True)

updates_D = adam(
    loss_D,
    params_D,
    learning_rate = learningrate_var
)

train_D = theano.function(
    [learningrate_var, batchsize_var, prior_variance_var],
    loss_D,
    updates = updates_D,
    allow_input_downcast = True
)

params_G = get_all_params(generator, trainable=True)

updates_G = adam(
    loss_G,
    params_G,
    learning_rate = learningrate_var
)

train_G = theano.function(
    [x_var, y_var, learningrate_var, batchsize_var],
    loss_G,
    updates = updates_G,
    allow_input_downcast = True
)

evaluate_discriminator = theano.function(
    [w_var],
    get_output([discriminator['unnorm'],discriminator['norm']],w_var),
    allow_input_downcast = True
)

#this is to evaluate the log-likelihood of an arbitrary set of w
llh_for_w = T.nnet.sigmoid((y_var.dimshuffle(0,'x','x')*(x_var.dimshuffle(0,1,'x') * w_var.dimshuffle('x', 1, 0))).sum(1))

evaluate_loglikelihood = theano.function(
        [x_var, y_var, w_var],
        llh_for_w,
        allow_input_downcast = True
    )
In [72]:
w_grid_ratio, _ = evaluate_discriminator(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300,300)
In [73]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [74]:
w_sample_prior = sample_prior(PRIOR_VARIANCE, BATCH_SIZE)
w_sample_posterior = sample_generator(BATCH_SIZE)
In [75]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [76]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.exp(log_post), cmap='magma')

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [77]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))

loss_plot, = ax1.plot([], label='loss')

ax1.set_xlabel('epoch')
ax1.legend(loc='upper left')

ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [78]:
def train_animate(epoch_num, prog_bar, batch_size=200, steps_per_epoch=15):

    # Single training epoch
    
    for step in tnrange(steps_per_epoch, unit='step', leave=False):
        
        loss = np.asscalar(train_D(LEARNING_RATE, 
                                   batch_size, 
                                   PRIOR_VARIANCE))

    # Plot Loss

    loss_plot.set_xdata(np.append(loss_plot.get_xdata(), epoch_num))    
    loss_plot.set_ydata(np.append(loss_plot.get_ydata(), loss))
    loss_plot.set_label('loss ({:.2f})'.format(loss))
    
    ax1.set_xlabel('epoch {:2d}'.format(epoch_num))
    ax1.legend(loc='upper left')

    ax1.relim()
    ax1.autoscale_view()
    
    # Contour Plot
    
    ax2.cla()
        
    w_grid_ratio, _ = evaluate_discriminator(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300,300)
    
    w_sample_prior = sample_prior(PRIOR_VARIANCE, batch_size)
    w_sample_posterior = sample_generator(batch_size)

    inputs = np.vstack((w_sample_prior, w_sample_posterior))
    targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))
    
    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')

    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    
    # Progress Bar Updates
    
    prog_bar.update()
    prog_bar.set_postfix(loss=loss)

    return loss_plot,
In [79]:
# main training loop is managed by higher-order
# FuncAnimation which makes calls to an `animate` 
# function that encapsulates the logic of single
# training epoch. Has benefit of producing 
# animation but can incur significant overhead
with tqdm_notebook(total=PRETRAIN_EPOCHS, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=PRETRAIN_EPOCHS,
                         interval=200, # 5 fps
                         blit=True)

    anim_html5_video = anim.to_html5_video()
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

In [80]:
HTML(anim_html5_video)
Out[80]:
In [81]:
fig, ax = plt.subplots(figsize=(5, 5))

w_grid_ratio, _ = evaluate_discriminator(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300,300)

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Adversarial Training

In [82]:
llh_theano = evaluate_loglikelihood(X, y, w_grid.reshape(300*300, 2))
llh_theano.shape
Out[82]:
(3, 90000)
In [83]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

ax1.contourf(w1, w2, np.sum(llhs, axis=2), 
             cmap='magma')

ax1.set_xlabel('$w_1$')
ax1.set_ylabel('$w_2$')

ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)
ax1.set_title('numpy loglikelihood')

ax2.contourf(w1, w2, np.sum(np.log(llh_theano), 0).reshape(300,300), 
             cmap='magma')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
ax2.set_title('theano loglikelihood')

plt.show()
In [84]:
np.allclose(np.sum(llhs, axis=2),
            np.sum(np.log(llh_theano), 0).reshape(300,300))
Out[84]:
True
In [85]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))

global_epoch = 0

loss_plot_inference, = ax1.plot([], label='inference')
loss_plot_discrim, = ax1.plot([], label='discriminator')

ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.legend(loc='upper left')

ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [86]:
def train_animate(epoch_num, prog_bar, batch_size=200, 
                  steps_per_epoch=15):

    global global_epoch, loss_plot_inference, loss_plot_discrim
    
    # Single training epoch

    ## Ratio estimator training

    for _ in tnrange(150, unit='step', desc='discriminator', 
                     leave=False):

        loss_d = np.asscalar(train_D(LEARNING_RATE, 
                                     batch_size, 
                                     PRIOR_VARIANCE))
    
    ## Inference model training

    for _ in tnrange(1, unit='step', desc='inference', leave=False):

        loss_g = np.asscalar(train_G(X, y, LEARNING_RATE, batch_size))

    global_epoch += 1

    # Plot Loss

    loss_plot_discrim.set_xdata(np.append(loss_plot_discrim.get_xdata(), global_epoch))
    loss_plot_discrim.set_ydata(np.append(loss_plot_discrim.get_ydata(), loss_d))
    loss_plot_discrim.set_label('discriminator ({:.2f})'.format(loss_d))
    
    loss_plot_inference.set_xdata(np.append(loss_plot_inference.get_xdata(), global_epoch))    
    loss_plot_inference.set_ydata(np.append(loss_plot_inference.get_ydata(), loss_g))
    loss_plot_inference.set_label('inference ({:.2f})'.format(loss_g))

    ax1.set_xlabel('epoch {:2d}'.format(global_epoch))
    ax1.legend(loc='upper left')
    
    ax1.relim()
    ax1.autoscale_view()
    
    # Contour Plot
    
    ax2.cla()
        
    w_grid_ratio, _ = evaluate_discriminator(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300,300)
    
    w_sample_prior = sample_prior(PRIOR_VARIANCE, batch_size)
    w_sample_posterior = sample_generator(batch_size)

    inputs = np.vstack((w_sample_prior, w_sample_posterior))
    targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))
    
    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')

    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)

    # Progress Bar Updates
    
    prog_bar.update()
    prog_bar.set_postfix(loss_g=loss_g, loss_d=loss_d)

    return loss_plot_inference, loss_plot_discrim
In [87]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[87]:
In [88]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[88]:
In [89]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[89]:
In [90]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[90]:
In [ ]:
 
In [45]:
plt.figure(figsize=(16,8))


plt.subplot(1,2,1)
plt.contourf(wrange, wrange, evaluate_discriminator(w.T)[0].reshape(300,300).T)

W = sample_generator(100)
plt.plot(W[:,0],W[:,1],'g.')

W = sample_prior(prior_variance, 100)
plt.plot(W[:,0],W[:,1],'r.')
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.title('estimated log density ratio $\Phi^{-1}(D)$')

plt.subplot(1,2,2)
plt.contourf(wrange, wrange, (llh1+llh2+llh3).reshape(300,300).T)
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);

plt.title('log likelihood');
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-45-00c7d0990f48> in <module>()
      3 
      4 plt.subplot(1,2,1)
----> 5 plt.contourf(wrange, wrange, evaluate_discriminator(w.T)[0].reshape(300,300).T)
      6 
      7 W = sample_generator(100)

NameError: name 'wrange' is not defined
In [ ]:
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.axis('square');
plt.xlim([w_min,w_max])
plt.ylim([w_min,w_max])
plt.title('theano loglikelihood')

plt.subplot(1,2,2)
plt.contourf(w1, w2, (llh1+llh2+llh3).reshape(300,300).T,cmap='gray');
plt.axis('square');
plt.xlim([w_min,w_max])
plt.ylim([w_min,w_max])
plt.title('numpy loglikelihood')

assert np.allclose(llh1+llh2+llh3,np.log(llh_theano).sum(0))
In [ ]:
plt.figure(figsize=(16,8))


plt.subplot(1,2,1)
plt.contourf(wrange, wrange, evaluate_discriminator(w.T)[0].reshape(300,300).T)

W = sample_generator(100)
plt.plot(W[:,0],W[:,1],'g.')

W = sample_prior(prior_variance, 100)
plt.plot(W[:,0],W[:,1],'r.')
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.title('estimated log density ratio $\Phi^{-1}(D)$')

plt.subplot(1,2,2)
plt.contourf(wrange, wrange, (llh1+llh2+llh3).reshape(300,300).T)
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);

plt.title('log likelihood');
In [ ]:
X_train = np.concatenate([x1,x2,x3],axis=1).T
y_train = np.array([y1,y2,y3])
llh_theano = evaluate_loglikelihood(X_train, y_train, w.T)

plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.contourf(wrange, wrange ,np.log(llh_theano).sum(0).reshape(300,300).T,cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.title('theano loglikelihood')

plt.subplot(1,2,2)
plt.contourf(wrange, wrange, (llh1+llh2+llh3).reshape(300,300).T,cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.title('numpy loglikelihood')

assert np.allclose(llh1+llh2+llh3,np.log(llh_theano).sum(0))
In [ ]:
plt.contourf(wrange, wrange, np.exp(logpost.reshape(300,300).T),cmap='gray');
plt.axis('square');
W = sample_generator(1000)
plt.plot(W[:,0],W[:,1],'.g')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.title('true log posterior');
In [ ]:
sns.set_style('whitegrid')
plt.subplot(1,2,2)

W = sample_generator(5000)
plot = sns.kdeplot(W[:,0],W[:,1])
plt.axis('square')
plot.set(xlim=(wmin,wmax))
plot.set(ylim=(wmin,wmax))
plt.title('kde of approximate posterior')

plt.subplot(1,2,1)
plt.contourf(wrange, wrange, np.exp(logpost.reshape(300,300).T),cmap='gray');
plt.axis('square');
W = sample_generator(100)
plt.plot(W[:,0],W[:,1],'.g')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.title('true posterior');
In [ ]:
 

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 7)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
from tqdm import tnrange
Using TensorFlow backend.
In [3]:
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([  .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer=Adam(lr=LEARNING_RATE),
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 4729044328 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)4727278616 dense_1: Denseinput:output:(None, 2)(None, 10)4729044328->4727278616 4729043432 dense_2: Denseinput:output:(None, 10)(None, 20)4727278616->4729043432 4729057176 logit: Denseinput:output:(None, 20)(None, 1)4729043432->4729057176 4721891144 activation_1: Activationinput:output:(None, 1)(None, 1)4729057176->4721891144
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.53143018484115601, 1.0]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 4732353784 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)4732435816 dense_3: Denseinput:output:(None, 3)(None, 10)4732353784->4732435816 4732434640 dense_4: Denseinput:output:(None, 10)(None, 20)4732435816->4732434640 4716115280 dense_5: Denseinput:output:(None, 20)(None, 2)4732434640->4716115280
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(200, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(200, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
Out[37]:
(-5, 5)
Discriminator pre-training
In [38]:
def train_animate(epoch_num, batch_size=200, steps_per_epoch=15):

    for step in range(steps_per_epoch):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

    ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    ax.text(0.05, 0.05, 
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'        
             'loss: {loss:.2f}').format(**train_info), 
            transform=ax.transAxes, bbox=props)

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    return ax
In [39]:
FuncAnimation(fig, train_animate, frames=60, 
              interval=200, # 5 fps
              blit=False)
Out[39]:
In [40]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [41]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [42]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [43]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Evidence lower bound

In [44]:
def set_trainable(model, trainable):
    """inorder traversal"""
    model.trainable = trainable

    if isinstance(model, Model): # i.e. has layers
        for layer in model.layers:
            set_trainable(layer, trainable)
In [45]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[45]:
<tf.Tensor 'Sigmoid:0' shape=(300, 300, 3) dtype=float32>
In [46]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[46]:
<tf.Tensor 'mul_33:0' shape=(300, 300, 3) dtype=float32>
In [47]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [48]:
sess = K.get_session()
In [49]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[49]:
True
In [50]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [51]:
def make_elbo(ratio_estimator):
    
    set_trainable(ratio_estimator, False)
    
    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)

    return elbo
In [52]:
elbo = make_elbo(ratio_estimator)
In [53]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [54]:
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
In [55]:
inference.compile(loss=inference_loss, 
                  optimizer=Adam(lr=LEARNING_RATE))
In [56]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [57]:
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
Out[57]:
<tf.Tensor 'concat:0' shape=(200, 3) dtype=float32>
In [58]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
Out[58]:
-3.8279114
In [59]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
 32/200 [===>..........................] - ETA: 0s
Out[59]:
3.8279113578796387

Training

In [60]:
for epoch in tnrange(200, desc='epoch'):

    set_trainable(ratio_estimator, False)

    for _ in tnrange(1, desc='generator'):

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, np.tile(y, reps=(BATCH_SIZE, 1)))

    set_trainable(discriminator, True)

    for _ in tnrange(3*50, desc='discriminator'):

        w_sample_prior = prior.rvs(size=BATCH_SIZE)

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))

        metrics_discrim = discriminator.train_on_batch(inputs, targets)

In [61]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [62]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [63]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [64]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.gray)

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

My PhD in Weeks

This is your PhD, and it's ending one week at a time.
  1 2 3 4
1        
       
       
       
5        
       
       
       
       
10        
       
       
       
       
15        
       
       
       
       
20        
       
       
       
       
25        
       
      x
      x
      x
30       x
      x
      x
      x
      x
35       x
      x
      x
      x
      x
40       x
      x
      x
      x
      x
45       x
      x
      x
      x
      x
50       x
      x
      x

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 6)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
Using TensorFlow backend.
In [3]:
plt.style.use('seaborn-notebook')
# display animation inline
plt.rc('animation', html='html5')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
PRIOR_VARIANCE = 2.

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 4766144608 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)4766145224 dense_1: Denseinput:output:(None, 2)(None, 10)4766144608->4766145224 4767336208 dense_2: Denseinput:output:(None, 10)(None, 20)4766145224->4767336208 4766143992 logit: Denseinput:output:(None, 20)(None, 1)4767336208->4766143992 4766535464 activation_1: Activationinput:output:(None, 1)(None, 1)4766143992->4766535464
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.72363483905792236, 0.40000000596046448]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 4769765752 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)4769787520 dense_3: Denseinput:output:(None, 3)(None, 10)4769765752->4769787520 4769190184 dense_4: Denseinput:output:(None, 10)(None, 20)4769787520->4769190184 4769949288 dense_5: Denseinput:output:(None, 20)(None, 2)4769190184->4769949288
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(128, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(128, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
Discriminator pre-training
In [38]:
def train_animate(epoch_num, batch_size=128, steps_per_epoch=20):

    for step in range(steps_per_epoch):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

    ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    ax.text(0.05, 0.05, 
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'        
             'loss: {loss:.2f}').format(**train_info), 
            transform=ax.transAxes, bbox=props)

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    return ax
In [39]:
FuncAnimation(fig, train_animate, frames=50, 
              interval=200, # 5 fps
              blit=False)
Out[39]:
In [40]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [41]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [42]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [43]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Evidence lower bound

In [44]:
def set_trainable(model, trainable):
    """inorder traversal"""
    model.trainable = trainable

    if isinstance(model, Model): # i.e. has layers
        for layer in model.layers:
            set_trainable(layer, trainable)
In [45]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[45]:
<tf.Tensor 'Sigmoid:0' shape=(300, 300, 3) dtype=float32>
In [46]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[46]:
<tf.Tensor 'mul_33:0' shape=(300, 300, 3) dtype=float32>
In [47]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [48]:
sess = K.get_session()
In [49]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[49]:
True
In [50]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [51]:
def make_elbo(ratio_estimator):
    
    set_trainable(ratio_estimator, False)
    
    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)

    return elbo
In [52]:
elbo = make_elbo(ratio_estimator)
In [53]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [59]:
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
In [60]:
inference.compile(loss=inference_loss, 
                  optimizer='adam')
In [61]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [62]:
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
Out[62]:
<tf.Tensor 'concat_1:0' shape=(128, 3) dtype=float32>
In [63]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
Out[63]:
-3.9920437
In [64]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
 32/128 [======>.......................] - ETA: 0s
Out[64]:
3.9920437335968018

Training

In [70]:
for epoch in range(3*200):

    set_trainable(ratio_estimator, False)

    for _ in range(1):

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, np.tile(y, reps=(BATCH_SIZE, 1)))

    set_trainable(discriminator, True)

    for _ in range(3*50):

        w_sample_prior = prior.rvs(size=BATCH_SIZE)

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))

        metrics_discrim = discriminator.train_on_batch(inputs, targets)
In [71]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [72]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [73]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [74]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [ ]:
 

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 5)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
Using TensorFlow backend.
In [3]:
plt.style.use('seaborn-notebook')
# display animation inline
plt.rc('animation', html='html5')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
D_BATCH_SIZE = 128
G_BATCH_SIZE = 128
PRIOR_VARIANCE = 2.

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 4796441544 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)4797073168 dense_1: Denseinput:output:(None, 2)(None, 10)4796441544->4797073168 4796631920 dense_2: Denseinput:output:(None, 10)(None, 20)4797073168->4796631920 4796443560 logit: Denseinput:output:(None, 20)(None, 1)4796631920->4796443560 4795146192 activation_1: Activationinput:output:(None, 1)(None, 1)4796443560->4795146192
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.63304531574249268, 0.60000002384185791]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 4783101656 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)4797501680 dense_3: Denseinput:output:(None, 3)(None, 10)4783101656->4797501680 4797502688 dense_4: Denseinput:output:(None, 10)(None, 20)4797501680->4797502688 4799943456 dense_5: Denseinput:output:(None, 20)(None, 2)4797502688->4799943456
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(128, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(128, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
Discriminator pre-training
In [38]:
def train_animate(epoch_num, batch_size=128, steps_per_epoch=20):

    for step in range(steps_per_epoch):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

    ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    ax.text(0.05, 0.05, 
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'        
             'loss: {loss:.2f}').format(**train_info), 
            transform=ax.transAxes, bbox=props)

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    return ax
In [39]:
FuncAnimation(fig, train_animate, frames=50, 
              interval=200, # 5 fps
              blit=False)
Out[39]:
In [40]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [41]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [42]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [43]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Inference Model Training

In [44]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[44]:
<tf.Tensor 'Sigmoid:0' shape=(300, 300, 3) dtype=float32>
In [45]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[45]:
<tf.Tensor 'mul_33:0' shape=(300, 300, 3) dtype=float32>
In [46]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [47]:
sess = K.get_session()
In [48]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[48]:
True
In [49]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [50]:
def make_elbo(ratio_estimator):

    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)

    return elbo
In [51]:
elbo = make_elbo(ratio_estimator)
In [52]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [94]:
inference.compile(loss=make_elbo(ratio_estimator), 
                  optimizer='adam')
In [95]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [96]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
 32/128 [======>.......................] - ETA: 0s
Out[96]:
-3.2424654364585876
In [105]:
# equiv. to use of tile above
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
Out[105]:
<tf.Tensor 'concat_10:0' shape=(128, 3) dtype=float32>
In [107]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
Out[107]:
-3.2424655

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 4)

In [177]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [178]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
In [179]:
plt.style.use('seaborn-notebook')
# display animation inline
plt.rc('animation', html='html5')
sns.set_context('notebook')
In [180]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [181]:
K.tf.__version__
Out[181]:
'1.2.1'
In [182]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
D_BATCH_SIZE = 128
G_BATCH_SIZE = 128
PRIOR_VARIANCE = 2.

Bayesian Logistic Regression (Synthetic Data)

In [183]:
w_min, w_max = -5, 5
In [184]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [185]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[185]:
(300, 300, 2)
In [186]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [187]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[187]:
(300, 300)
In [188]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [189]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [190]:
X = np.vstack((x1, x2, x3))
X.shape
Out[190]:
(3, 2)
In [191]:
y1 = 1
y2 = 1
y3 = 0
In [192]:
y = np.stack((y1, y2, y3))
y.shape
Out[192]:
(3,)
In [193]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [194]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[194]:
(300, 300, 3)
In [195]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [196]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [197]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 140297740092864 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)140297740092640 dense_1: Denseinput:output:(None, 2)(None, 10)140297740092864->140297740092640 140297740091520 dense_2: Denseinput:output:(None, 10)(None, 20)140297740092640->140297740091520 140297740093312 logit: Denseinput:output:(None, 20)(None, 1)140297740091520->140297740093312 140297740622592 activation_1: Activationinput:output:(None, 1)(None, 1)140297740093312->140297740622592
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.63745558261871338, 0.80000001192092896]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 140297740961440 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)140297740025472 dense_3: Denseinput:output:(None, 3)(None, 10)140297740961440->140297740025472 140297740023568 dense_4: Denseinput:output:(None, 10)(None, 20)140297740025472->140297740023568 140297740591512 dense_5: Denseinput:output:(None, 20)(None, 2)140297740023568->140297740591512
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(128, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(128, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
Discriminator pre-training
In [38]:
def train_animate(epoch_num, batch_size=128, steps_per_epoch=20):

    for step in range(steps_per_epoch):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

    ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    ax.text(0.05, 0.05, 
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'        
             'loss: {loss:.2f}').format(**train_info), 
            transform=ax.transAxes, bbox=props)

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    return ax
In [39]:
FuncAnimation(fig, train_animate, frames=50, 
              interval=200, # 5 fps
              blit=False)
Out[39]:
In [40]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [41]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [42]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [43]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Inference Model Training

In [214]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[214]:
<tf.Tensor 'Sigmoid_13:0' shape=(300, 300, 3) dtype=float32>
In [215]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[215]:
<tf.Tensor 'mul_50:0' shape=(300, 300, 3) dtype=float32>
In [207]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [208]:
sess = K.get_session()
In [209]:
np.sum(llhs, axis=-1)
Out[209]:
array([[-20.08, -20.01, -19.94, ...,  -2.72,  -2.69,  -2.66],
       [-20.02, -19.95, -19.88, ...,  -2.68,  -2.64,  -2.61],
       [-19.95, -19.88, -19.81, ...,  -2.63,  -2.6 ,  -2.56],
       ..., 
       [-15.1 , -15.03, -14.96, ...,  -2.55,  -2.52,  -2.49],
       [-15.13, -15.06, -14.99, ...,  -2.59,  -2.56,  -2.53],
       [-15.16, -15.09, -15.02, ...,  -2.64,  -2.61,  -2.58]])
In [210]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[210]:
True
In [211]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [216]:
def make_elbo(ratio_estimator):

    def elbo(y_true, y_pred):
        log_likelihood = - K.binary_crossentropy(y_pred, y_true)
        kl_estimate = K.mean(ratio_estimator(y_pred), axis=-1)
        return log_likelihood - kl_estimate
    
    return elbo
In [217]:
elbo = make_elbo(ratio_estimator)
In [ ]: